{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ac7b59ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils import data\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f69ebc0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "true_w = torch.tensor([2, -3.4])\n",
    "true_b = 4.2\n",
    "features, labels = d2l.synthetic_data(true_w, true_b, 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1081d6ea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[ 0.3551,  1.7673],\n",
       "         [-1.2251, -1.6617],\n",
       "         [ 0.6565, -0.1224],\n",
       "         [-1.6279, -1.2256],\n",
       "         [-0.5363,  1.3034],\n",
       "         [-2.3440,  1.3905],\n",
       "         [-1.6693,  0.1119],\n",
       "         [ 1.7608,  0.9719],\n",
       "         [-1.0507,  1.1723],\n",
       "         [-0.2266,  1.4051]]),\n",
       " tensor([[-1.1001],\n",
       "         [ 7.3931],\n",
       "         [ 5.9389],\n",
       "         [ 5.1135],\n",
       "         [-1.3003],\n",
       "         [-5.2101],\n",
       "         [ 0.4818],\n",
       "         [ 4.4260],\n",
       "         [-1.9013],\n",
       "         [-1.0406]])]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def load_array(data_arrays, batch_size, is_train=True):\n",
    "    \"\"\"构造一个Pytorch数据迭代器\"\"\"\n",
    "    dataset = data.TensorDataset(*data_arrays) # * 表示传入的参数args是一个元组，** 表示传入的参数args是一个字典\n",
    "    return data.DataLoader(dataset, batch_size, shuffle=is_train) # 这是一个数据加载器，用于加载数据并生成小批次供模型训练。\n",
    "\n",
    "batch_size = 10\n",
    "data_iter = load_array((features, labels), batch_size)\n",
    "\n",
    "next(iter(data_iter))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ccb6d406",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "11efcc31",
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Sequential(nn.Linear(2, 1)) # (2,1)为输入输出维度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9ce2f00a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[0].weight.data.normal_(0, 0.01) # 均值为0，标准差为0.01\n",
    "net[0].bias.data.fill_(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "905b5ca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8821e36e",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = torch.optim.SGD(net.parameters(), lr=0.03) #随机梯度下降算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3058b460",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 0.000218\n",
      "epoch 2, loss 0.000103\n",
      "epoch 3, loss 0.000104\n"
     ]
    }
   ],
   "source": [
    "num_epoch = 3\n",
    "for epoch in range(num_epoch):\n",
    "    for X, y in data_iter:\n",
    "        '''\n",
    "        在代码中，net(X) 返回了模型对输入数据 X 的预测结果，\n",
    "        这是一个张量（tensor），通常是一个包含预测值的向量或矩阵，\n",
    "        取决于模型的输出维度和小批次的大小。然后，这个预测结果与真实目标 y 一起传递给损失函数 loss，\n",
    "        并由 loss 计算出一个标量值 l，表示整个小批次的损失。\n",
    "        '''\n",
    "        l = loss(net(X), y) # net(X) 表示将输入数据 X 通过模型前向传播，得到模型的预测结果\n",
    "        trainer.zero_grad()\n",
    "        l.backward()\n",
    "        trainer.step() # 来更新模型的参数\n",
    "    l = loss(net(features), labels)\n",
    "    print(f'epoch {epoch + 1}, loss {l:f}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "d2l",
   "language": "python",
   "name": "d2l"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
